Nickel (Ni) Agriculture site RF_Regression

Author

Alex Koiter

Load libraries

In [1]:
suppressPackageStartupMessages({
  library(tidyverse)
  library(randomForest)
  library(terra)
  library(caret)
  library(patchwork)
  library(sf)
})
Warning: package 'ggplot2' was built under R version 4.4.3
Warning: package 'randomForest' was built under R version 4.4.2
Warning: package 'terra' was built under R version 4.4.3
Warning: package 'caret' was built under R version 4.4.2
Warning: package 'patchwork' was built under R version 4.4.2

Load data

In [2]:
attribute <- c("plan_curvature", "profile_curvature", "saga_wetness_index", "catchment_area", "relative_slope_position", "channel_network_distance", "elevation")

data <- read_csv(here::here("./notebooks/ag_terrain_data.csv"), show_col_types = FALSE) %>%
  select("x", "y", "ni", any_of(attribute))

Map soil property

In [3]:
temp_rast <- rast(data)
crs(x = temp_rast, warn=FALSE) <- "epsg:26914"

coords <- read_csv(here::here("./notebooks/coords.csv"), show_col_types = FALSE) %>% 
  st_as_sf(coords = c("long", "lat"),  crs = 4326) %>%
  st_transform(crs = 26914)

p1<- ggplot() +
  tidyterra::geom_spatraster(data = temp_rast, aes(fill = ni)) +
  scale_fill_viridis_c(name = "Ni", breaks = seq(4, 9, 1)) +
  geom_sf(data = filter(coords, site == "Agriculture")) +
  theme_bw(base_size = 12) +
  theme(axis.title = element_blank(),
        axis.text = element_blank(),
        axis.ticks = element_blank(),
        legend.position = "bottom") +
  scale_y_continuous(expand = c(0,0)) +
  scale_x_continuous(expand = c(0,0)) +
  ggspatial::annotation_scale(location = "bl") +
  ggspatial::annotation_north_arrow(location = "br")

p2 <- ggplot() +
  tidyterra::geom_spatraster(data = temp_rast, aes(fill = elevation)) +
  scale_fill_viridis_c(name = "Elevation (m)", option = "inferno") +
  geom_sf(data = filter(coords, site == "Agriculture")) +
  theme_bw(base_size = 12) +
  theme(axis.title = element_blank(),
        axis.text = element_blank(),
        axis.ticks = element_blank(),
        legend.position = "bottom") +
  scale_y_continuous(expand = c(0,0)) +
  scale_x_continuous(expand = c(0,0)) +
  ggspatial::annotation_scale(location = "bl") +
  ggspatial::annotation_north_arrow(location = "br")
In [4]:
#| fig-width: 12
#| fig-asp: 0.5
p1+p2  

Create training, validation and testing datasets

The data are randomly split into three sets: approximately 60 % for training, 20 % for validation, and 20 % for testing.

In [5]:
set.seed(123) # makes it reproducible

temp2 <- data %>%
  mutate(dataset = sample(c("train", "validation", "test"), size = nrow(.), replace = TRUE, prob = c(0.6, 0.2, 0.2)))

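# Note: train keeps both the "train" and "validation" rows (and the dataset column)
# so that caret can later fit on the former and evaluate on the latter
train <- temp2 %>% 
  filter(dataset == "train" | dataset == "validation") 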

validation <- temp2 %>% 
  filter(dataset == "validation") %>%
  select(-dataset)

test <- temp2 %>% 
  filter(dataset == "test") %>%
  select(-dataset)
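Because each row's label is drawn independently with `prob = c(0.6, 0.2, 0.2)`, the realized split is only approximately 60/20/20. A minimal sketch of how the realized shares could be checked, using a hypothetical 1,000-row data frame (`demo` is an illustration, not the notebook's data):

```r
set.seed(123)

# Hypothetical stand-in for the terrain data
demo <- data.frame(id = 1:1000)

# Same labelling approach as above: one independent draw per row
demo$dataset <- sample(c("train", "validation", "test"),
                       size = nrow(demo), replace = TRUE,
                       prob = c(0.6, 0.2, 0.2))

# Realized share of rows in each set
split_share <- prop.table(table(demo$dataset))
round(split_share, 3)
```

If exact proportions matter, shuffling a fixed-length vector of labels is one alternative, at the cost of departing from the approach used here.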

Feature selection

vifstep() calculates the variance inflation factor (VIF) for all variables, excludes the one with the highest VIF (if it is greater than the threshold th), and repeats the procedure until no variable with a VIF greater than th remains. Here th = 8.

In [6]:
features <- usdm::vifstep(rast(select(filter(train, dataset == "train"), -dataset, -ni)), th = 8)
features
No variable from the 7 input variables has collinearity problem. 

The linear correlation coefficients ranges between: 
min correlation ( elevation ~ profile_curvature ):  0.002419244 
max correlation ( elevation ~ saga_wetness_index ):  -0.8977745 

---------- VIFs of the remained variables -------- 
                 Variables      VIF
1           plan_curvature 1.350442
2        profile_curvature 1.315861
3       saga_wetness_index 5.372984
4           catchment_area 1.240758
5  relative_slope_position 2.262207
6 channel_network_distance 1.811908
7                elevation 5.545539
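
The stepwise procedure can be sketched in base R. The VIF of a variable is 1 / (1 - R²), where R² comes from regressing that variable on all the others. The helpers `vif_all()` and `vif_step()` below are hypothetical names for illustration only; this is not the usdm implementation, and the toy data are simulated.

```r
# Hypothetical helper: VIF of every column, from regressing it on the others
vif_all <- function(df) {
  sapply(names(df), function(v) {
    f  <- reformulate(setdiff(names(df), v), response = v)
    r2 <- summary(lm(f, data = df))$r.squared
    1 / (1 - r2)
  })
}

# Hypothetical helper: drop the highest-VIF column until all VIFs are <= th
vif_step <- function(df, th = 8) {
  repeat {
    if (ncol(df) < 2) break
    v <- vif_all(df)
    if (max(v) <= th) break
    df <- df[, names(v)[-which.max(v)], drop = FALSE]
  }
  names(df)
}

set.seed(1)
demo <- data.frame(a = rnorm(200), b = rnorm(200))
demo$c <- demo$a + demo$b + rnorm(200, sd = 0.01)  # nearly collinear with a and b

kept <- vif_step(demo, th = 8)
kept
```

In this toy example one of the three columns is removed, after which the remaining two fall below the threshold. In the notebook output above, all seven terrain attributes are already below th = 8, so nothing is excluded.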

Remove correlated features from the data sets

In [7]:
train <- temp2 %>% 
  filter(dataset == "train" | dataset == "validation") %>%
  usdm::exclude(features) %>%
  select(-x, -y)
Warning in usdm::exclude(., features): No variable to exclude!
validation <- temp2 %>% 
  filter(dataset == "validation") %>%
  select(-dataset) %>%
  usdm::exclude(features) %>%
  select(-x, -y)
Warning in usdm::exclude(., features): No variable to exclude!
test <- temp2 %>% 
  filter(dataset == "test") %>%
  select(-dataset) %>%
  usdm::exclude(features) %>%
  select(-x, -y)
Warning in usdm::exclude(., features): No variable to exclude!

Tune the training RF model using the validation dataset

Instructions: https://stackoverflow.com/questions/18155482/how-to-specify-a-validation-holdout-set-to-caret. This uses the caret package: I include the validation set inside my training set and define the resampling indices so that the model is fit on the training rows and evaluated only on the validation rows. The purpose of this step is to optimize the mtry parameter.

In [8]:
# Fit on the "train" rows only; the remaining (validation) rows are held out for evaluation
tc <- trainControl(method = "cv", number = 1,
                   index = list(Fold1 = which(train$dataset == "train")),
                   savePredictions = TRUE)

set.seed(456)
validate.rf <- train(ni ~ ., data = select(train, -dataset), method = "rf", trControl = tc)
plot(validate.rf)

my_mtry <- validate.rf$finalModel$mtry
my_mtry
[1] 7
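
The same idea, fitting on the training rows and scoring each candidate mtry on the held-out rows, can be sketched without caret. This is only an illustration under stated assumptions: it uses the built-in mtcars data rather than the terrain data, a fixed (non-random) hold-out split, and a simple grid of all possible mtry values, which is not necessarily the grid caret evaluates.

```r
library(randomForest)

set.seed(456)

# Assumed toy data: predict mpg from four predictors
dat <- mtcars[, c("mpg", "cyl", "disp", "hp", "wt")]

# Deterministic hold-out: every 4th row is used for validation
is_train <- seq_len(nrow(dat)) %% 4 != 0

# Validation RMSE for each candidate mtry (1 to the number of predictors)
rmse_by_mtry <- sapply(1:4, function(m) {
  fit  <- randomForest(mpg ~ ., data = dat[is_train, ], mtry = m, ntree = 200)
  pred <- predict(fit, newdata = dat[!is_train, ])
  sqrt(mean((dat$mpg[!is_train] - pred)^2))
})

# Candidate with the lowest validation RMSE
best_mtry <- which.min(rmse_by_mtry)
best_mtry
```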

Validation back test

The model is refit on the training rows with the tuned mtry, and the validation rows are passed as the test set (xtest and ytest).

In [9]:
set.seed(789)

rf.fit <- randomForest(ni ~ ., data = select(filter(train, dataset == "train"), -dataset), 
                       ntree = 500, keep.forest = TRUE, importance = TRUE, mtry = my_mtry,
                       ytest = select(filter(train, dataset == "validation"), -dataset)$ni,
                       xtest = dplyr::select(select(filter(train, dataset == "validation"), -dataset), -ni))
rf.fit

Call:
 randomForest(formula = ni ~ ., data = select(filter(train, dataset ==      "train"), -dataset), ntree = 500, keep.forest = TRUE, importance = TRUE,      mtry = my_mtry, ytest = select(filter(train, dataset == "validation"),          -dataset)$ni, xtest = dplyr::select(select(filter(train,          dataset == "validation"), -dataset), -ni)) 
               Type of random forest: regression
                     Number of trees: 500
No. of variables tried at each split: 7

          Mean of squared residuals: 0.3382089
                    % Var explained: 93.12
                       Test set MSE: 0.33
                    % Var explained: 93.69
In [10]:
oob_val <- sqrt(rf.fit$mse)
test_val <- sqrt(rf.fit$test$mse)

val_plot <- tibble(`Out of Bag Error` = oob_val,
                   Validation = test_val,
                   ntrees = 1:rf.fit$ntree) %>%
  pivot_longer(cols = -ntrees, names_to = "Metric", values_to = "RMSE" )

ggplot(data = val_plot, aes(ntrees, RMSE, color = Metric)) +
  geom_line() +
  theme_bw() +
  xlab("Number of trees")

Testing the RF model

The RF model is used to predict the test dataset, and the predicted values are compared against the observed values.

In [11]:
prediction <- predict(rf.fit, newdata = test)
test_plot <- data.frame(pred = prediction, obs = test$ni)

r_sq <- summary(lm(pred~obs, data = test_plot))
r_sq

Call:
lm(formula = pred ~ obs, data = test_plot)

Residuals:
     Min       1Q   Median       3Q      Max 
-2.34093 -0.27905 -0.05581  0.20593  3.04222 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept) 2.416376   0.274444   8.805   <2e-16 ***
obs         0.919600   0.009176 100.218   <2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 0.5572 on 771 degrees of freedom
Multiple R-squared:  0.9287,    Adjusted R-squared:  0.9286 
F-statistic: 1.004e+04 on 1 and 771 DF,  p-value: < 2.2e-16
ggplot(data = test_plot, aes(y = pred, x = obs)) +
  geom_point() +
  ggpmisc::stat_poly_line() +
  ggpmisc::stat_poly_eq() +
  theme_bw() +
  geom_abline(slope = 1, intercept = 0) +
  coord_fixed(ratio = 1)
Registered S3 methods overwritten by 'ggpp':
  method                  from   
  heightDetails.titleGrob ggplot2
  widthDetails.titleGrob  ggplot2

mse_test <- test_plot %>%
  summarise(mse = mean((obs - pred)^2))
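
The two summaries used here measure different things. The R² from `lm(pred ~ obs)` equals the squared Pearson correlation between predictions and observations, so it is unaffected by a constant bias, whereas the MSE measures departure from the 1:1 line. A small sketch with made-up numbers (the `obs` and `pred` vectors below are illustrative, not results from this model):

```r
# Made-up observed and predicted values
obs  <- c(30, 32, 35, 28, 40)
pred <- c(31, 31, 34, 29, 38)

# Error relative to the 1:1 line
mse  <- mean((obs - pred)^2)
rmse <- sqrt(mse)

# R-squared of the regression of predicted on observed
r2_lm  <- summary(lm(pred ~ obs))$r.squared
r2_cor <- cor(pred, obs)^2  # identical for a one-predictor linear model

# Adding a constant bias leaves R-squared unchanged but inflates the MSE
pred_biased <- pred + 5
r2_biased   <- summary(lm(pred_biased ~ obs))$r.squared
mse_biased  <- mean((obs - pred_biased)^2)
```

Reporting both, as done above, therefore guards against a model that tracks the observations closely but is systematically offset.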

Importance

In [12]:
ImpData <- as.data.frame(importance(rf.fit)) %>%
  mutate(Var.Names = row.names(.)) %>%
 `row.names<-`(as.character(1:nrow(importance(rf.fit))))

Top 10 based on IncNodePurity

In [13]:
slice_max(ImpData,order_by = IncNodePurity, n = 10)
    %IncMSE IncNodePurity                Var.Names
7 357.00744    9503.59795                elevation
3 156.19691     956.46701       saga_wetness_index
6 106.37618     391.69942 channel_network_distance
5  48.44527     201.98926  relative_slope_position
2  30.28675     155.10546        profile_curvature
4  34.86870     145.61204           catchment_area
1  19.60779      98.83987           plan_curvature
In [14]:
p1 <- ggplot(slice_max(ImpData,order_by = IncNodePurity, n = 10), aes(x = fct_reorder(Var.Names, IncNodePurity), y = IncNodePurity)) +
  geom_segment(aes(x = fct_reorder(Var.Names, IncNodePurity), xend = fct_reorder(Var.Names, IncNodePurity), y = 0, yend = IncNodePurity), color="skyblue") +
  geom_point(aes(size = `%IncMSE`), color = "blue", alpha = 0.6) +
  theme_bw() +
  coord_flip() +
  labs(x = "Terrain Attribute") +
  theme(legend.position="bottom")

Top 10 based on %IncMSE

In [15]:
slice_max(ImpData,order_by = `%IncMSE`, n = 10)
    %IncMSE IncNodePurity                Var.Names
7 357.00744    9503.59795                elevation
3 156.19691     956.46701       saga_wetness_index
6 106.37618     391.69942 channel_network_distance
5  48.44527     201.98926  relative_slope_position
4  34.86870     145.61204           catchment_area
2  30.28675     155.10546        profile_curvature
1  19.60779      98.83987           plan_curvature
In [16]:
p2<- ggplot(slice_max(ImpData,order_by = `%IncMSE`, n = 10), aes(x = fct_reorder(Var.Names, `%IncMSE`), y = `%IncMSE`)) +
  geom_segment(aes(x = fct_reorder(Var.Names, `%IncMSE`), xend = fct_reorder(Var.Names, `%IncMSE`), y = 0, yend = `%IncMSE`), color="skyblue") +
  geom_point(aes(size = IncNodePurity), color = "blue", alpha = 0.6) +
  theme_bw() +
  coord_flip() +
  labs(x = "Terrain Attribute") +
  theme(legend.position="bottom")

Combined

In [17]:
#| label: test
#| fig-cap: Ag Ni 
p1+p2

Figure: Ag Ni

Write data

In [18]:
importance_data <- ImpData %>%
  mutate(MSE_rank = rank(-`%IncMSE`)) %>%
  mutate(Purity_rank = rank(-IncNodePurity)) %>%
  mutate(site = "Agriculture",
        property = "Ni")

if(file.exists(here::here("./notebooks/importance_data.csv"))) {
  importance_data_final <- read_csv(here::here("./notebooks/importance_data.csv")) |>
    rows_upsert(importance_data, by = c("MSE_rank", "site", "property")) 
} else importance_data_final <- importance_data
Rows: 168 Columns: 7
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr (3): Var.Names, site, property
dbl (4): %IncMSE, IncNodePurity, MSE_rank, Purity_rank

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
write_csv(importance_data_final, here::here("./notebooks/importance_data.csv"))


model_performance <- tibble(MSE = rf.fit$mse[length(rf.fit$mse)],
       Var_exp = rf.fit$rsq[length(rf.fit$rsq)],
       MSE_test = rf.fit$test$mse[length(rf.fit$test$mse)],
       Var_exp_test = rf.fit$test$rsq[length(rf.fit$test$rsq)],
       R2 = r_sq$r.squared, 
       mse_test = mse_test$mse,
       site = "Agriculture",
       property = "Ni")

if(file.exists(here::here("./notebooks/model_performance_data.csv"))) {
  model_performance_final <- read_csv(here::here("./notebooks/model_performance_data.csv")) |>
    rows_upsert(model_performance, by = c("site", "property")) 
} else model_performance_final <- model_performance
Rows: 22 Columns: 8
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr (2): site, property
dbl (6): MSE, Var_exp, MSE_test, Var_exp_test, R2, mse_test

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
write_csv(model_performance_final, here::here("./notebooks/model_performance_data.csv"))

Predict original data

In [19]:
orig_data <- read_csv(here::here("./notebooks/ag_data_49.csv"))  %>%
  select("ni", any_of(attribute))
Rows: 49 Columns: 22
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
dbl (22): plan_curvature, profile_curvature, saga_wetness_index, catchment_a...

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
prediction_org <- predict(rf.fit, newdata = orig_data)
test_plot_org <- data.frame(pred = prediction_org, obs = orig_data$ni)

r_sq_org <- summary(lm(pred~obs, data = test_plot_org))
r_sq_org

Call:
lm(formula = pred ~ obs, data = test_plot_org)

Residuals:
     Min       1Q   Median       3Q      Max 
-0.82513 -0.18294 -0.01255  0.21904  2.20621 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)  1.38224    0.93000   1.486    0.144    
obs          0.95391    0.03131  30.471   <2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 0.4953 on 47 degrees of freedom
Multiple R-squared:  0.9518,    Adjusted R-squared:  0.9508 
F-statistic: 928.5 on 1 and 47 DF,  p-value: < 2.2e-16
ggplot(data = test_plot_org, aes(y = pred, x = obs)) +
  geom_point() +
  ggpmisc::stat_poly_line() +
  ggpmisc::stat_poly_eq() +
  theme_bw() +
  geom_abline(slope = 1, intercept = 0) +
  coord_fixed(ratio = 1)

mse_org <- test_plot_org %>%
  summarise(mse = mean((obs - pred)^2))

Write data

In [20]:
model_performance_org <- tibble(MSE = mse_org$mse,
       R2 = r_sq_org$r.squared, 
       site = "Agriculture",
       property = "ni")

if(file.exists(here::here("./notebooks/model_performance_data_49.csv"))) {
  model_performance_final_org <- read_csv(here::here("./notebooks/model_performance_data_49.csv")) |>
    rows_upsert(model_performance_org, by = c("site", "property")) 
} else model_performance_final_org <- model_performance_org
Rows: 24 Columns: 4
── Column specification ────────────────────────────────────────────────────────
Delimiter: ","
chr (2): site, property
dbl (2): MSE, R2

ℹ Use `spec()` to retrieve the full column specification for this data.
ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
write_csv(model_performance_final_org, here::here("./notebooks/model_performance_data_49.csv"))